[typing] Add type hints to `__init__` methods in `torch.distributions`. by randolf-scholz · Pull Request #144197 · pytorch/pytorch · GitHub

[typing] Add type hints to __init__ methods in torch.distributions. #144197


Closed

Conversation

@randolf-scholz (Contributor) commented Jan 4, 2025

Fixes #144196
Extends #144106 and #144110

Open Problems:

  • Annotating with numbers.Number is a bad idea; consider using float, SupportsFloat, or some Protocol instead. #144197 (comment)

Notes

  • beta.py: needed to add type: ignore since broadcast_all is untyped.
  • categorical.py: converted else branches of mutually exclusive arguments to if branches[1] (see the sketch after this list).
  • dirichlet.py: replaced axis with dim arguments. (Split out into a separate PR: "Dirichlet.mode: use dim= instead of axis=", #144402.)
  • geometric.py: converted else branches of mutually exclusive arguments to if branches[1].
  • independent.py: fixed bug in Independent.__init__ where tuple[int, ...] could be passed to Distribution.__init__ instead of torch.Size. EDIT: it turns out the bug is related to the typing of torch.Size. (See "Improve static typing for torch.Size", #144218.)
  • independent.py: made Independent a generic class of its base distribution.
  • multivariate_normal.py: converted else branches of mutually exclusive arguments to if branches[1].
  • relaxed_bernoulli.py: added class-level type hint for base_dist.
  • relaxed_categorical.py: added class-level type hint for base_dist.
  • transforms.py: Added missing argument to docstring of ReshapeTransform. (Split out into a separate PR: "ReshapeTransform: added missing argument in docstring", #144401.)
  • transforms.py: Fixed bug in AffineTransform.sign (could return Tensor instead of int). (Split out into a separate PR: "Fix AffineTransform.sign", #144400.)
  • transforms.py: Added type: ignore comments to AffineTransform.log_abs_det_jacobian[2]; replaced torch.abs(scale) with scale.abs().
  • transforms.py: Added type: ignore comments to AffineTransform.__eq__[2].
  • transforms.py: Fixed type hint on CumulativeDistributionTransform.domain. Note that this is still an LSP violation, because Transform.domain is defined as Constraint, but Distribution.domain is defined as Optional[Constraint].
  • skipped: constraints.py, constraints_registry.py, kl.py, utils.py, exp_family.py, __init__.py.
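
To illustrate the else-to-if conversion referenced in several of the notes above (see footnote 1), here is a minimal sketch; the class and attribute names are placeholders, not the PR's actual code:

```python
from typing import Optional

from torch import Tensor


class _Example:
    def __init__(
        self, probs: Optional[Tensor] = None, logits: Optional[Tensor] = None
    ) -> None:
        if (probs is None) == (logits is None):
            raise ValueError(
                "Either `probs` or `logits` must be specified, but not both."
            )
        # Two separate `if` branches, each narrowing its own argument to `Tensor`.
        # With `if probs is not None: ... else: ...`, mypy cannot deduce that
        # `logits` is not None inside the `else`, which would force a `type: ignore`.
        if probs is not None:
            self.param: Tensor = probs
        if logits is not None:
            self.param = logits
```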

Remark

TransformedDistribution: __init__ uses the check if reinterpreted_batch_ndims > 0:, which can lead to the creation of Independent distributions with only 1 component. This results in awkward code like base_dist.base_dist in LogisticNormal.

import torch
from torch.distributions import *
b1 = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
b2 = MultivariateNormal(torch.tensor([0.0]), torch.eye(1))
t = StickBreakingTransform()
d1 = TransformedDistribution(b1, t)
d2 = TransformedDistribution(b2, t)
print(d1.base_dist)  # Independent with 1 dimension
print(d2.base_dist)  # MultivariateNormal

One could consider changing this to if reinterpreted_batch_ndims > 1:.
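
For instance, the awkward double indirection in LogisticNormal (a sketch; the wrapping behavior follows the description above):

```python
import torch
from torch.distributions import LogisticNormal

ln = LogisticNormal(torch.tensor([0.0]), torch.tensor([1.0]))
# The Normal base is wrapped in an Independent with one reinterpreted dim,
# so reaching it takes two attribute hops:
print(type(ln.base_dist))            # Independent
print(type(ln.base_dist.base_dist))  # Normal
```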

cc @fritzo @neerajprad @alicanb @nikitaved @ezyang @malfet @xuzhao9 @gramster

Footnotes

  1. Otherwise, we would have to add a bunch of type: ignore comments to make mypy happy, as it isn't able to perform the type narrowing. Ideally, such code should be replaced with structural pattern matching once support for Python 3.9 is dropped.

  2. Usage of isinstance(value, numbers.Real) leads to problems with static typing, as the numbers module is not supported by mypy (see https://github.com/python/mypy/issues/3186). This results in us having to add type-ignore comments in several places.
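
As a hedged sketch of the structural pattern matching mentioned in footnote 1 (Python ≥ 3.10 only, and not part of this PR; the helper name is invented):

```python
from typing import Optional

from torch import Tensor


def _pick_param(probs: Optional[Tensor], logits: Optional[Tensor]) -> Tensor:
    # Hypothetical helper illustrating the dispatch as a match statement.
    match (probs, logits):
        case (Tensor() as p, None):
            return p
        case (None, Tensor() as l):
            return l
        case _:
            raise ValueError(
                "Either `probs` or `logits` must be specified, but not both."
            )
```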

pytorch-bot bot commented Jan 4, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/144197

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit aafae2d with merge base 87a63a9:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@randolf-scholz (Contributor, Author) commented:

@pytorchbot label "module: typing"
@pytorchbot label "module: distributions"

@pytorch-bot pytorch-bot bot added the module: typing Related to mypy type annotations label Jan 4, 2025
@randolf-scholz (Contributor, Author) commented:

@pytorchbot label "module: distributions"

@pytorch-bot pytorch-bot bot added the module: distributions Related to torch.distributions label Jan 4, 2025
@randolf-scholz (Contributor, Author) commented:

@pytorchbot label "release notes: python_frontend"

@pytorch-bot pytorch-bot bot added the release notes: python_frontend python frontend release notes category label Jan 4, 2025
@malfet (Contributor) commented Jan 15, 2025

@randolf-scholz I feel like 50+% of this change could be landed separately, even with mypy suppression for the whole file (say, something like torch/distributions/weibull.py), which would make incremental reviews faster.

@randolf-scholz (Contributor, Author) commented:

> @randolf-scholz I feel like 50+% of this change could be landed separately, even with mypy suppression for the whole file (say, something like torch/distributions/weibull.py), which would make incremental reviews faster.

One could do that, but I think it would make the most sense to first fix the numbers.Number issue and then update the annotations afterwards. Otherwise, we may end up with incorrect annotations or having to update them again.

@randolf-scholz (Contributor, Author) commented:

I added annotations to broadcast_all in this PR, since this function is used in virtually all __init__ methods. This required adding a few assert statements to make mypy happy when mutually exclusive arguments are used (a sketch follows the snippet below).

# Before:
def broadcast_all(*values): ...

# After:
# FIXME: Use `(*values: *Ts) -> tuple[Tensor for T in Ts]` if Mapping-Type is ever added.
# See https://github.com/python/typing/issues/1216#issuecomment-2126153831
def broadcast_all(*values: Union[Tensor, Number]) -> tuple[Tensor, ...]: ...
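
A minimal sketch of why the asserts are needed (an assumed call-site pattern; the PR's actual code may differ):

```python
from typing import Optional

from torch import Tensor
from torch.distributions.utils import broadcast_all


def _init_param(probs: Optional[Tensor] = None, logits: Optional[Tensor] = None) -> Tensor:
    # Hypothetical call site; the annotated `broadcast_all` rejects `None`
    # statically, so the mutually exclusive arguments must be narrowed first.
    if probs is None:
        assert logits is not None  # mypy cannot infer mutual exclusivity on its own
        (param,) = broadcast_all(logits)
    else:
        (param,) = broadcast_all(probs)
    return param
```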
@randolf-scholz (Contributor, Author) commented on this snippet:

Currently, the annotation does not cover the case of __torch_function__.

This is somewhat outside the scope of this PR; possibly, a Protocol type should be added to torch.types that covers the concrete semantics (a hypothetical sketch follows).
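
As a thought experiment (not part of the PR), such a Protocol could look roughly like this; the name SupportsTorchFunction and the exact signature are assumptions:

```python
from typing import Any, Optional, Protocol, runtime_checkable


@runtime_checkable
class SupportsTorchFunction(Protocol):
    # Hypothetical protocol; the signature mirrors the usual
    # `__torch_function__` override shape.
    @classmethod
    def __torch_function__(
        cls,
        func: Any,
        types: Any,
        args: tuple = (),
        kwargs: Optional[dict] = None,
    ) -> Any: ...
```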

@randolf-scholz force-pushed the distributions_type_hint_init branch from fd46589 to 212325e on January 28, 2025, 16:11
@randolf-scholz requested a review from malfet on January 28, 2025, 16:32
@randolf-scholz (Contributor, Author) commented:

One nice thing to add would be overloads for __init__ methods with mutually exclusive arguments, like

def __init__(self, probs=None, logits=None, validate_args=None):

However, the current API does not make use of keyword-only arguments for some reason. The following would be much nicer, imo:

@overload
def __init__(self, probs: Tensor, *, validate_args: Optional[bool] = None) -> None: ...
@overload
def __init__(self, *, logits: Tensor, validate_args: Optional[bool] = None) -> None: ...
def __init__(self, probs: Optional[Tensor] = None, *, logits: Optional[Tensor] = None, validate_args: Optional[bool] = None) -> None: ...

However, this would require some backward-incompatible changes; for instance, the argument order in Binomial.__init__ would need to change (see the sketch below). But maybe it's a possibility for the future.
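
For context, a sketch of the compatibility concern, assuming Binomial's current signature Binomial(total_count=1, probs=None, logits=None, validate_args=None):

```python
import torch
from torch.distributions import Binomial

# Valid today, since `probs` is the second positional parameter:
d = Binomial(10, torch.tensor([0.3]))

# Under keyword-only overloads like those sketched above, the same call
# would have to be written as:
d = Binomial(10, probs=torch.tensor([0.3]))
```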

github-actions bot commented Apr 5, 2025

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Apr 5, 2025
@Skylion007 (Collaborator) commented Apr 6, 2025

@randolf-scholz Mind doing a rebase? @malfet I think this is blocked by your review.

@malfet (Contributor) commented Apr 6, 2025

@pytorchbot rebase

@pytorchmergebot (Collaborator) commented:

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot (Collaborator) commented:

Rebase failed due to Command git -C /home/runner/work/pytorch/pytorch rebase refs/remotes/origin/viable/strict pull/144197/head returned non-zero exit code 1

Rebasing (1/31)
Auto-merging torch/distributions/utils.py
CONFLICT (content): Merge conflict in torch/distributions/utils.py
error: could not apply ef2329f5cc9... added type hints to lazy_property
hint: Resolve all conflicts manually, mark them as resolved with
hint: "git add/rm <conflicted_files>", then run "git rebase --continue".
hint: You can instead skip this commit: run "git rebase --skip".
hint: To abort and get back to the state before "git rebase", run "git rebase --abort".
hint: Disable this message with "git config set advice.mergeConflict false"
Could not apply ef2329f5cc9... added type hints to lazy_property

Raised by https://github.com/pytorch/pytorch/actions/runs/14295053243

@malfet (Contributor) commented Apr 6, 2025

@pytorchbot merge -f "Rebase fails, let's see if merge will be fine"

@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

timocafe pushed a commit to timocafe/pytorch that referenced this pull request on Apr 16, 2025: "[typing] Add type hints to __init__ methods in torch.distributions. (pytorch#144197)". The commit message duplicates the PR description above. (Pull Request resolved: pytorch#144197. Approved by: https://github.com/malfet. Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>.)
amathewc pushed a commit to amathewc/pytorch that referenced this pull request Apr 17, 2025
…`. (pytorch#144197)

Fixes pytorch#144196
Extends pytorch#144106 and pytorch#144110

## Open Problems:

- [ ] Annotating with `numbers.Number` is a bad idea, should consider using `float`, `SupportsFloat` or some `Procotol`. pytorch#144197 (comment)

# Notes

- `beta.py`: needed to add `type: ignore` since `broadcast_all` is untyped.
- `categorical.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2].
- ~~`
9A82
dirichlet.py`: replaced `axis` with `dim` arguments.~~ pytorch#144402
- `gemoetric.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2].
- ~~`independent.py`: fixed bug in `Independent.__init__` where `tuple[int, ...]` could be passed to `Distribution.__init__` instead of `torch.Size`.~~ **EDIT:** turns out the bug is related to typing of `torch.Size`. pytorch#144218
- `independent.py`: made `Independent` a generic class of its base distribution.
- `multivariate_normal.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2].
- `relaxed_bernoulli.py`: added class-level type hint for `base_dist`.
- `relaxed_categorical.py`: added class-level type hint for `base_dist`.
- ~~`transforms.py`: Added missing argument to docstring of `ReshapeTransform`~~ pytorch#144401
- ~~`transforms.py`: Fixed bug in `AffineTransform.sign` (could return `Tensor` instead of `int`).~~ pytorch#144400
- `transforms.py`: Added `type: ignore` comments to `AffineTransform.log_abs_det_jacobian`[^1]; replaced `torch.abs(scale)` with `scale.abs()`.
- `transforms.py`: Added `type: ignore` comments to `AffineTransform.__eq__`[^1].
- `transforms.py`: Fixed type hint on `CumulativeDistributionTransform.domain`. Note that this is still an LSP violation, because `Transform.domain` is defined as `Constraint`, but `Distribution.domain` is defined as `Optional[Constraint]`.
- skipped: `constraints.py`, `constraints_registry.py`, `kl.py`, `utils.py`, `exp_family.py`, `__init__.py`.

## Remark

`TransformedDistribution`: `__init__` uses the check `if reinterpreted_batch_ndims > 0:`, which can lead to the creation of `Independent` distributions with only 1 component. This results in awkward code like `base_dist.base_dist` in `LogisticNormal`.

```python
import torch
from torch.distributions import *
b1 = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
b2 = MultivariateNormal(torch.tensor([0.0]), torch.eye(1))
t = StickBreakingTransform()
d1 = TransformedDistribution(b1, t)
d2 = TransformedDistribution(b2, t)
print(d1.base_dist)  # Independent with 1 dimension
print(d2.base_dist)  # MultivariateNormal
```

One could consider changing this to `if reinterpreted_batch_ndims > 1:`.

[^1]: Usage of `isinstance(value, numbers.Real)` leads to problems with static typing, as the `numbers` module is not supported by `mypy` (see <python/mypy#3186>). This results in us having to add type-ignore comments in several places
[^2]: Otherwise, we would have to add a bunch of `type: ignore` comments to make `mypy` happy, as it isn't able to perform the type narrowing. Ideally, such code should be replaced with structural pattern matching once support for Python 3.9 is dropped.

Pull Request resolved: pytorch#144197
Approved by: https://github.com/malfet

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
Labels
  • Merged
  • module: distributions (Related to torch.distributions)
  • module: typing (Related to mypy type annotations)
  • open source
  • release notes: python_frontend (python frontend release notes category)
  • Stale
  • triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Development

Successfully merging this pull request may close these issues:

[typing] Add static type hints to torch.distributions.
7 participants