[typing] Add static type hints to `torch.distributions`. · Issue #144196 · pytorch/pytorch · GitHub

[typing] Add static type hints to torch.distributions. #144196


Closed
3 of 6 tasks
randolf-scholz opened this issue Jan 4, 2025 · 7 comments · May be fixed by #144219, #154711 or #154827
Labels
module: distributions Related to torch.distributions module: typing Related to mypy type annotations triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@randolf-scholz
Contributor
randolf-scholz commented Jan 4, 2025

🚀 The feature, motivation and pitch

The current lack of type hints causes issues; see, for instance, #76772.

Alternatives

No response

Additional context

No response

cc @fritzo @neerajprad @alicanb @nikitaved @ezyang @malfet @xuzhao9 @gramster

@cpuhrsch cpuhrsch added module: typing Related to mypy type annotations triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Jan 7, 2025
@cpuhrsch
Contributor
cpuhrsch commented Jan 7, 2025

Adding this for triage review to raise awareness among potential reviewers. Some of the PRs associated with this don't have immediately obvious reviewers, in my view.

@cpuhrsch cpuhrsch added triage review and removed triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Jan 7, 2025
pytorchmergebot pushed a commit that referenced this issue Jan 7, 2025
….distributions`. (#144110)

Fixes #76772, #144196
Extends #144106

- added type annotations to `lazy_property`.
- added type annotation to all `@property` and `@lazy_property` inside `torch.distributions` module.
- added a simple type-check unit test to ensure type inference is working.
- replaced deprecated annotations like `typing.List` with their modern counterparts.
- simplified `torch.Tensor` hints with plain `Tensor`, otherwise signatures can become very verbose.
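For readers unfamiliar with the descriptor, here is a minimal sketch of what a typed `lazy_property` can look like. This is a hypothetical simplification, not PyTorch's actual implementation: the generic parameters let a type checker infer the attribute's type at the access site, and the value is cached on the instance after the first access.

```python
# Hypothetical sketch of a typed lazy_property descriptor.
import functools
from typing import Any, Callable, Generic, Optional, TypeVar

T = TypeVar("T")
R = TypeVar("R")


class lazy_property(Generic[T, R]):
    def __init__(self, wrapped: Callable[[T], R]) -> None:
        self.wrapped = wrapped
        functools.update_wrapper(self, wrapped)

    def __get__(self, instance: Optional[T], owner: Any = None) -> R:
        if instance is None:
            # Class-level access returns the descriptor itself.
            return self  # type: ignore[return-value]
        value = self.wrapped(instance)
        # Cache by shadowing the (non-data) descriptor with an instance
        # attribute, so the computation runs at most once per instance.
        setattr(instance, self.wrapped.__name__, value)
        return value


class Example:
    """Demo class: `mean` is computed once and then cached."""
    calls = 0

    @lazy_property
    def mean(self) -> float:
        Example.calls += 1
        return 0.0
```

Because the descriptor defines no `__set__`, the cached instance attribute takes precedence on subsequent lookups, which is the same trick the real `lazy_property` relies on.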

Pull Request resolved: #144110
Approved by: https://github.com/Skylion007
@malfet
Contributor
malfet commented Jan 8, 2025

My concern about those typing PRs:

  • They still do not eliminate the untyped-defs suppression; see
    # mypy: allow-untyped-defs
  • Lots of them add the following import:
    from torch import Tensor

    In the past, I've seen anti-patterns where devs used torch.foo.bar.Tensor instead of torch.Tensor, and such imports encourage it. If we're going to discuss this one, it would be good to agree on whether it's OK to do something like that for the sake of type annotation or not.
    [edit] But I see this is already a pattern in torch.distributions, so it's probably OK to keep using it.

@randolf-scholz
Contributor Author
randolf-scholz commented Jan 8, 2025

@malfet The PRs are incremental,

  1. [typing] Add type hints to @property and @lazy_property in torch.distributions. #144110 Adds typing to lazy_property class and annotates returns of @property and @lazy_property annotated methods. This one adds a static test and is merged already.
  2. [typing] Add type hints to __init__ methods in torch.distributions. #144197 Adds typing to all __init__ methods. Not merged yet and there is a blocker.
  3. Full static typing for torch.distributions #144219 implements full static typing and eliminates all `# mypy: allow-untyped-defs` suppressions. However, while adding the annotations, mypy detected several LSP violations that may need to be addressed; currently, I just added `type: ignore` comments.
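To illustrate the kind of LSP (Liskov substitution principle) violation mypy flags here, consider this minimal sketch with hypothetical classes (not actual torch code): the subclass narrows a parameter type that the base class accepts more broadly.

```python
# Hypothetical example of an override that mypy rejects as an LSP violation.
class Base:
    def log_prob(self, value: float) -> float:
        return -value


class Narrowed(Base):
    # mypy: "Argument 1 of 'log_prob' is incompatible with supertype 'Base'".
    # A caller holding a Base reference may pass any float, which Narrowed
    # claims not to accept, so the override must be silenced with a
    # `type: ignore[override]` (or the signature fixed).
    def log_prob(self, value: int) -> float:  # type: ignore[override]
        return -float(value)
```

At runtime the code works fine; the violation only surfaces once the signatures are annotated, which is why fully typing the module exposed these spots.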

@randolf-scholz
Contributor Author
randolf-scholz commented Jan 8, 2025
  • In the past, I've seen anti-patterns where devs used torch.foo.bar.Tensor instead of torch.Tensor, and such imports encourage it. If we're going to discuss this one, it would be good to agree on whether it's OK to do something like that for the sake of type annotation or not.

I don't quite follow. Do you mean that the issue is that people attempt to annotate like torch.cuda.Tensor instead of torch.Tensor? This annotation already gets used a lot: https://github.com/search?q=repo%3Apytorch%2Fpytorch+%2F%28%3Fi%29-%3E+Tensor%3A%2F+language%3APython&type=code&l=Python

@malfet
Contributor
malfet commented Jan 8, 2025

I don't quite follow. Do you mean that the issue is that people attempt to annotate like torch.cuda.Tensor instead of torch.Tensor?

Yes. But if this is already a common pattern, then perhaps it's fine.

@malfet malfet added the module: distributions Related to torch.distributions label Jan 13, 2025
@jbschlosser jbschlosser added triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module and removed triage review labels Jan 13, 2025
@Skylion007
Collaborator

What's the status on the typing PRs? What needs to be reviewed?

@randolf-scholz
Contributor Author

@Skylion007 Logically, the next step would be the __init__-PR #144197. What's currently still missing is:

  1. [typing] Add type hints to __init__ methods in torch.distributions. #144197 (comment)
  2. Some static tests. Generally what would be useful here is to simply type check the regular test suite, since this will cover the most important use cases. I have a branch on my fork for this: https://github.com/randolf-scholz/pytorch/tree/distribution_type_hint_test

timocafe pushed a commit to timocafe/pytorch that referenced this issue Apr 16, 2025
…`. (pytorch#144197)

Fixes pytorch#144196
Extends pytorch#144106 and pytorch#144110

## Open Problems:

- [ ] Annotating with `numbers.Number` is a bad idea; consider using `float`, `SupportsFloat`, or some `Protocol` instead. pytorch#144197 (comment)

## Notes

- `beta.py`: needed to add `type: ignore` since `broadcast_all` is untyped.
- `categorical.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2].
- ~~`dirichlet.py`: replaced `axis` with `dim` arguments.~~ pytorch#144402
- `geometric.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2].
- ~~`independent.py`: fixed bug in `Independent.__init__` where `tuple[int, ...]` could be passed to `Distribution.__init__` instead of `torch.Size`.~~ **EDIT:** turns out the bug is related to typing of `torch.Size`. pytorch#144218
- `independent.py`: made `Independent` a generic class of its base distribution.
- `multivariate_normal.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2].
- `relaxed_bernoulli.py`: added class-level type hint for `base_dist`.
- `relaxed_categorical.py`: added class-level type hint for `base_dist`.
- ~~`transforms.py`: Added missing argument to docstring of `ReshapeTransform`~~ pytorch#144401
- ~~`transforms.py`: Fixed bug in `AffineTransform.sign` (could return `Tensor` instead of `int`).~~ pytorch#144400
- `transforms.py`: Added `type: ignore` comments to `AffineTransform.log_abs_det_jacobian`[^1]; replaced `torch.abs(scale)` with `scale.abs()`.
- `transforms.py`: Added `type: ignore` comments to `AffineTransform.__eq__`[^1].
- `transforms.py`: Fixed type hint on `CumulativeDistributionTransform.domain`. Note that this is still an LSP violation, because `Transform.domain` is defined as `Constraint`, but `Distribution.domain` is defined as `Optional[Constraint]`.
- skipped: `constraints.py`, `constraints_registry.py`, `kl.py`, `utils.py`, `exp_family.py`, `__init__.py`.
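The `else`-to-`if` conversion mentioned for `categorical.py` and friends can be sketched as follows. This is a hypothetical function loosely mirroring the mutually exclusive `probs`/`logits` constructor arguments, not the actual torch code:

```python
# Hypothetical sketch: why the mutually-exclusive-arguments pattern needs
# an explicit second `if` instead of a bare `else` to satisfy mypy.
from typing import List, Optional


def pick(probs: Optional[List[float]], logits: Optional[List[float]]) -> List[float]:
    if (probs is None) == (logits is None):
        raise ValueError("Either `probs` or `logits` must be specified, but not both.")
    if probs is not None:
        return probs
    # A plain `else:` would leave `logits` typed as Optional here, because
    # narrowing `probs` tells mypy nothing about `logits`; an explicit
    # second check performs the narrowing without a `type: ignore`.
    if logits is not None:
        return [x - max(logits) for x in logits]
    raise AssertionError("unreachable")
```

With structural pattern matching (`match (probs, logits): case (p, None) if p is not None: ...`), the same exclusivity could be expressed more directly, which is what footnote 2 alludes to once Python 3.9 support is dropped.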

## Remark

`TransformedDistribution`: `__init__` uses the check `if reinterpreted_batch_ndims > 0:`, which can lead to the creation of `Independent` distributions with only 1 component. This results in awkward code like `base_dist.base_dist` in `LogisticNormal`.

```python
import torch
from torch.distributions import *
b1 = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
b2 = MultivariateNormal(torch.tensor([0.0]), torch.eye(1))
t = StickBreakingTransform()
d1 = TransformedDistribution(b1, t)
d2 = TransformedDistribution(b2, t)
print(d1.base_dist)  # Independent with 1 dimension
print(d2.base_dist)  # MultivariateNormal
```

One could consider changing this to `if reinterpreted_batch_ndims > 1:`.

[^1]: Usage of `isinstance(value, numbers.Real)` leads to problems with static typing, as the `numbers` module is not supported by `mypy` (see <python/mypy#3186>). This forces us to add type-ignore comments in several places.
[^2]: Otherwise, we would have to add a bunch of `type: ignore` comments to make `mypy` happy, as it isn't able to perform the type narrowing. Ideally, such code should be replaced with structural pattern matching once support for Python 3.9 is dropped.
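A small sketch of the `numbers.Real` issue from footnote 1 (hypothetical function, not torch code): the runtime check behaves as expected, but mypy cannot narrow through the `numbers` ABCs, so arithmetic on the narrowed value typically needs a type-ignore comment.

```python
# Hypothetical example: isinstance with numbers.Real works at runtime
# for int, float, bool, Fraction, ..., but mypy does not model the
# numbers ABCs (python/mypy#3186), so the comparisons below would need
# a `type: ignore` under strict checking.
import numbers


def sign_of(scale: object) -> int:
    if isinstance(scale, numbers.Real):
        return int(scale > 0) - int(scale < 0)
    raise TypeError(f"expected a real number, got {type(scale).__name__}")
```

Annotating with `float` (which also admits `int` and `bool` under PEP 484's numeric-tower convention) or a `Protocol` like `SupportsFloat` sidesteps the problem entirely, which is what the open problem at the top suggests.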

Pull Request resolved: pytorch#144197
Approved by: https://github.com/malfet

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
amathewc pushed a commit to amathewc/pytorch that referenced this issue Apr 17, 2025
pytorchmergebot pushed a commit that referenced this issue May 30, 2025
nWEIdia pushed a commit to nWEIdia/pytorch that referenced this issue Jun 2, 2025
qingyi-yan pushed a commit to qingyi-yan/pytorch that referenced this issue Jun 3, 2025